# adapted from https://github.com/seungeunrho/minimalRL

import gym
import collections
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from synthetic_env import synthetic_env

import argparse

# Parse command-line arguments. NOTE: this runs at import time, so importing
# this module from another program will consume that program's sys.argv.
parser = argparse.ArgumentParser(description='Synthetic Succesor Feature Deep Q-learning')
parser.add_argument('--seed', default=0, type=int, help='seed')
parser.add_argument('--c', default=0.01, type=float, help='c')
parser.add_argument('--gamma', default=0.95, type=float, help='gamma')

args = parser.parse_args()

# simulation parameters
seed = args.seed
# Seed both PyTorch and Python's `random` module (the latter drives
# epsilon-greedy exploration and replay-buffer sampling).
torch.manual_seed(seed)
random.seed(seed)
# Episodes between log lines; main() also syncs the target network on this
# schedule (except at episode 0).
print_interval = 10
# Number of training episodes per task.
num_epi=1000
# Only used to build the phi checkpoint file name loaded in main().
phi_train_num_epi=1000
# Not referenced anywhere else in this file.
skip_phi_train=True

# agent hyperparameters

dqn_lr = 1e-1 # Adam learning rate; alternatives left by the author: 1e-2, 1e-3, 1e-6
# Discount factor: used in the TD target, passed to the environment, and
# embedded in checkpoint file names.
gamma         = args.gamma
# Maximum number of transitions kept in the replay buffer.
buffer_limit  = 50000
# Mini-batch size drawn from the replay buffer for each gradient step.
batch_size    = 32
# `use_gpi` and `zero_shot` are not read anywhere else in this file; with the
# values below the conditional is a no-op.
use_gpi = False
zero_shot = False
if zero_shot:
    use_gpi = False

# environment parameters (forwarded to synthetic_env in main())
state_space=10000
action_space=4
state_dim=10
phi_dim=10
n_tasks=2
# task=0
# Maximum number of environment steps per episode.
n_steps=20
# Passed to synthetic_env as `c` and embedded in the saved DQN file name.
c = args.c

class ReplayBuffer():
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Each stored transition is a ``(s, a, r, s_prime, done_mask)`` tuple. Once
    the buffer is full, appending a new transition drops the oldest one.
    """

    def __init__(self, limit=None):
        """Create an empty buffer.

        Args:
            limit: Maximum number of transitions to keep. ``None`` (the
                default) falls back to the module-level ``buffer_limit``.
        """
        if limit is None:
            limit = buffer_limit
        self.buffer = collections.deque(maxlen=limit)

    def put(self, transition):
        """Append one ``(s, a, r, s_prime, done_mask)`` transition."""
        self.buffer.append(transition)

    def sample(self, n):
        """Draw ``n`` distinct transitions uniformly at random.

        Returns:
            A tuple ``(s, a, r, s_prime, done_mask)`` of tensors whose first
            dimension is ``n``. States, rewards and done masks are float32
            and actions are int64. Each action, reward and done mask is
            wrapped in a one-element list, so scalar entries yield tensors
            of shape ``(n, 1)``.

        Raises:
            ValueError: If ``n`` exceeds the number of stored transitions
                (raised by ``random.sample``).
        """
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for s, a, r, s_prime, done_mask in mini_batch:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])

        # Every dtype is pinned explicitly instead of being inferred from the
        # stored Python/NumPy values: `gather` needs int64 indices, and the
        # TD target must match the float32 network output.
        return (
            torch.tensor(s_lst, dtype=torch.float),
            torch.tensor(a_lst, dtype=torch.int64),
            torch.tensor(r_lst, dtype=torch.float),
            torch.tensor(s_prime_lst, dtype=torch.float),
            torch.tensor(done_mask_lst, dtype=torch.float),
        )

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def reset(self):
        """Discard all stored transitions, keeping the same capacity."""
        self.buffer = collections.deque(maxlen=self.buffer.maxlen)

class Qnet(nn.Module):
    """Small MLP mapping a state vector to one Q-value per action.

    The forward pass is ``fc1 -> ReLU -> fc3``. ``fc2`` is constructed but
    not used by ``forward``; it is kept so that parameter initialisation
    order and the ``state_dict`` keys stay the same.
    """

    def __init__(self, state_dim, action_space):
        """Build the layers.

        Args:
            state_dim: Size of the input state vector.
            action_space: Number of discrete actions (output size).
        """
        super().__init__()
        self.state_dim = state_dim
        self.action_space = action_space
        self.hidden_dim = 8

        # model layers
        self.fc1 = nn.Linear(self.state_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc3 = nn.Linear(self.hidden_dim, self.action_space)

    def forward(self, x):
        """Return the Q-values for state(s) ``x``."""
        hidden = F.relu(self.fc1(x))
        return self.fc3(hidden)

    def sample_action(self, obs, epsilon):
        """Pick an action epsilon-greedily.

        With probability ``epsilon`` a uniformly random action index is
        returned; otherwise the index of the largest Q-value for ``obs``.

        Args:
            obs: State tensor fed to ``forward``. ``argmax`` is taken over
                the flattened output, so a single un-batched state is
                expected.
            epsilon: Exploration probability in ``[0, 1]``.

        Returns:
            The chosen action as a Python ``int``.
        """
        coin = random.random()
        if coin < epsilon:
            # Exploring: the network output is not needed, so skip the
            # forward pass entirely.
            return random.randint(0, self.action_space - 1)
        # Acting only needs the argmax, so no autograd graph is built.
        with torch.no_grad():
            return self.forward(obs).argmax().item()
            
def train(q, q_target, memory, optimizer, its=10, batch=None, discount=None):
    """Run several DQN gradient steps on mini-batches from the replay buffer.

    Each step samples a batch, regresses ``Q(s, a)`` onto the one-step TD
    target ``r + discount * max_a' Q_target(s', a') * done_mask`` with an MSE
    loss, and applies one optimizer update to ``q``.

    Args:
        q: Online Q-network being trained.
        q_target: Target Q-network used for the bootstrap value. It is only
            evaluated here, never updated.
        memory: Object with ``sample(n)`` returning the tensors
            ``(s, a, r, s_prime, done_mask)``; ``a`` must be int64 indices
            usable with ``gather(1, a)``.
        optimizer: Optimizer over the parameters of ``q``.
        its: Number of gradient steps to perform (default 10).
        batch: Mini-batch size; ``None`` uses the module-level ``batch_size``.
        discount: Discount factor; ``None`` uses the module-level ``gamma``.

    Returns:
        The mean loss over the ``its`` gradient steps, as a float.

    Raises:
        ValueError: If ``its`` is smaller than 1.
    """
    if its < 1:
        raise ValueError("its must be a positive integer")
    if batch is None:
        batch = batch_size
    if discount is None:
        discount = gamma

    dqn_loss = 0.0
    for _ in range(its):
        s, a, r, s_prime, done_mask = memory.sample(batch)

        # Q-value of the action that was actually taken in each transition.
        q_a = q(s).gather(1, a)

        # The TD target is a fixed regression label: build it without
        # autograd so backward() does not push gradients into q_target.
        with torch.no_grad():
            max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
            target = r + discount * max_q_prime * done_mask

        loss = F.mse_loss(q_a, target)
        dqn_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return dqn_loss / its

# DEBUG
def _check_agent_grad(model):
    """Debug helper: print the gradient of every named parameter of ``model``.

    Parameters that have no gradient yet (``grad`` is ``None``) print
    ``None`` instead.
    """
    for name, p in model.named_parameters():
        # Explicit check instead of a bare ``except`` so that unrelated
        # errors (and Ctrl-C) are no longer silently swallowed.
        if p.grad is None:
            print(None)
        else:
            print(name, 'grad', p.grad.data)

def _check_agent_weights(model):
    """Debug helper: print the values of every named parameter of ``model``."""
    for name, p in model.named_parameters():
        # Reading ``p.data`` cannot fail, so no exception guard is needed.
        print(name, 'weight', p.data)

def _check_agent_require_grad(model):
    """Debug helper: print the ``requires_grad`` flag of every named parameter."""
    for name, p in model.named_parameters():
        # Label matches what is printed (the flag, not the weight values).
        print(name, 'requires_grad', p.requires_grad)


def main():
    """Train a DQN agent on the synthetic environment and save its weights.

    Builds the environment, an online and a target Q-network, a replay
    buffer and an Adam optimizer from the module-level settings. Task 0 is
    skipped. For task 1 the environment's ``phi`` weights are loaded from a
    checkpoint file, the agent is trained for ``num_epi`` episodes with a
    linearly decaying epsilon, and the online network's ``state_dict`` is
    written to disk. Progress is printed every ``print_interval`` episodes,
    which is also when the target network is synchronised.
    """
    torch.manual_seed(seed)
    # NOTE(review): the environment is always built with seed=0, independent
    # of --seed; presumably so every run sees the same environment - confirm.
    env = synthetic_env(state_space=state_space,
                        action_space=action_space,
                        state_dim=state_dim,
                        phi_dim=phi_dim,
                        gamma=gamma,
                        n_tasks=n_tasks,
                        seed=0, c=c,
                        tildeP=False)
    try:
        q = Qnet(state_dim=state_dim, action_space=action_space)
        q_target = Qnet(state_dim=state_dim, action_space=action_space)
        q_target.load_state_dict(q.state_dict())
        memory = ReplayBuffer()

        optimizer = optim.Adam(q.parameters(), lr=dqn_lr)

        # Placeholder shown in the log until the first training round has
        # run; without it the first periodic log line raises
        # UnboundLocalError when the buffer has not yet exceeded 64
        # transitions.
        dqn_loss = 0.0

        for task in range(n_tasks):
            if task == 0:
                continue

            if task == 1:
                phi_name = f"phi_{gamma}_{state_space}_{state_dim}_{action_space}_{phi_dim}_{phi_train_num_epi}_0_eps_0"
                # NOTE: torch.load unpickles the file; only load checkpoints
                # from a trusted source.
                env.phi.load_state_dict(torch.load(phi_name))

            score = 0.0      # return accumulated since the last log line
            cum_score = 0.0  # return accumulated over the whole task
            memory.reset()

            for n_epi in range(num_epi):
                # Linear decay from 50% to 0% over the first 200 episodes.
                epsilon = max(0.0, 0.5 - 0.5*(n_epi/200))
                s, _ = env.reset()
                done = False

                step_count = 0
                while not done and step_count < n_steps:
                    a = q.sample_action(torch.from_numpy(s).float(), epsilon)
                    s_prime, r, _, done, _ = env.step(a, task)
                    # done_mask zeroes the bootstrap term on terminal steps.
                    done_mask = 0.0 if done else 1.0
                    memory.put((s, a, r, s_prime, done_mask))
                    s = s_prime

                    score += r
                    step_count += 1

                # Start learning once the buffer holds more than 64 samples.
                if memory.size() > 64:
                    dqn_loss = train(q, q_target, memory, optimizer)

                if n_epi % print_interval == 0:
                    if n_epi == 0:
                        print("task : {}, n_episode : {}, score : {:.4f}, cum. score : {:.4f}, phi loss : {:.4f}, sf loss : {:.4f}, gpi percent : {:.2f}%, n_buffer : {}, eps : {:.1f}%".format(
                            task, n_epi, score, score, 0., 0., 0., memory.size(), epsilon*100))
                        cum_score += score
                        score = 0.0
                    else:
                        cum_score += score
                        # Sync the target network on the logging schedule.
                        q_target.load_state_dict(q.state_dict())
                        print("task : {}, n_episode : {}, score : {:.4f}, cum. score : {:.4f}, phi loss : {:.4f}, dqn loss : {:.4f}, gpi percent : {:.2f}%, n_buffer : {}, eps : {:.1f}%".format(
                            task, n_epi, score/(print_interval), cum_score/(n_epi+1), 0, dqn_loss, 0, memory.size(), epsilon*100))
                        score = 0.0

            if task == 1:
                dqn_name = f"dqn_{c}-{gamma}-{seed}_eps_0.5"
                torch.save(q.state_dict(), dqn_name)
    finally:
        # Release the environment even if training is interrupted or fails.
        env.close()

# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()